{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train/市管宗教活动场所.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train/南涧彝族自治县市汶上县残疾人联合会_残疾人证办理_.csv</td>\n",
       "      <td>医疗卫生</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train/价格监测信息公开事项汇总信息_.xls</td>\n",
       "      <td>经济管理</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train/县级公共图书馆文化馆总分馆制信息_大安市.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train/县级公共图书馆文化馆总分馆制信息_陇南市武都区.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              filename label\n",
       "0                   train/市管宗教活动场所.csv  文化休闲\n",
       "1  train/南涧彝族自治县市汶上县残疾人联合会_残疾人证办理_.csv  医疗卫生\n",
       "2            train/价格监测信息公开事项汇总信息_.xls  经济管理\n",
       "3       train/县级公共图书馆文化馆总分馆制信息_大安市.csv  文化休闲\n",
       "4    train/县级公共图书馆文化馆总分馆制信息_陇南市武都区.csv  文化休闲"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('raw_data/answer_train.csv')\n",
    "\n",
    "print(train.shape)\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8000, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>test1/阳谷县司法所组织信息_.xls</td>\n",
       "      <td>工业</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test1/济宁市辽宁省朝阳市龙城区县残疾人联合会_残疾人证办理_.csv</td>\n",
       "      <td>城乡建设</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>test1/企业变更信息_.xls</td>\n",
       "      <td>城乡建设</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>test1/设立经营性互联网文化单位审批_万州区.csv</td>\n",
       "      <td>生态环境</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>test1/事业企业“双随机一公开”监察检查记录_.xls</td>\n",
       "      <td>城乡建设</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                filename label\n",
       "0                  test1/阳谷县司法所组织信息_.xls    工业\n",
       "1  test1/济宁市辽宁省朝阳市龙城区县残疾人联合会_残疾人证办理_.csv  城乡建设\n",
       "2                      test1/企业变更信息_.xls  城乡建设\n",
       "3           test1/设立经营性互联网文化单位审批_万州区.csv  生态环境\n",
       "4          test1/事业企业“双随机一公开”监察检查记录_.xls  城乡建设"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = pd.read_csv('raw_data/submit_example_test1.csv')\n",
    "\n",
    "print(test.shape)\n",
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>label_n</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>工业</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>文化休闲</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>教育科技</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>医疗卫生</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>文秘行政</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>生态环境</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>城乡建设</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>农业畜牧业</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>经济管理</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>交通运输</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>政法监察</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>财税金融</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>劳动人事</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>旅游服务</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>资源能源</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>商业贸易</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>气象水文测绘地震地理</td>\n",
       "      <td>16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>民政社区</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>信息产业</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>外交外事</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         label  label_n\n",
       "0           工业        0\n",
       "1         文化休闲        1\n",
       "2         教育科技        2\n",
       "3         医疗卫生        3\n",
       "4         文秘行政        4\n",
       "5         生态环境        5\n",
       "6         城乡建设        6\n",
       "7        农业畜牧业        7\n",
       "8         经济管理        8\n",
       "9         交通运输        9\n",
       "10        政法监察       10\n",
       "11        财税金融       11\n",
       "12        劳动人事       12\n",
       "13        旅游服务       13\n",
       "14        资源能源       14\n",
       "15        商业贸易       15\n",
       "16  气象水文测绘地震地理       16\n",
       "17        民政社区       17\n",
       "18        信息产业       18\n",
       "19        外交外事       19"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_label = pd.DataFrame({'label': train.label.value_counts(normalize=True).index.tolist(),\n",
    "                         'label_n': [i for i in range(train.label.nunique())]})\n",
    "\n",
    "df_label.head(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 3)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>label</th>\n",
       "      <th>label_n</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train/市管宗教活动场所.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train/南涧彝族自治县市汶上县残疾人联合会_残疾人证办理_.csv</td>\n",
       "      <td>医疗卫生</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train/价格监测信息公开事项汇总信息_.xls</td>\n",
       "      <td>经济管理</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train/县级公共图书馆文化馆总分馆制信息_大安市.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train/县级公共图书馆文化馆总分馆制信息_陇南市武都区.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              filename label  label_n\n",
       "0                   train/市管宗教活动场所.csv  文化休闲        1\n",
       "1  train/南涧彝族自治县市汶上县残疾人联合会_残疾人证办理_.csv  医疗卫生        3\n",
       "2            train/价格监测信息公开事项汇总信息_.xls  经济管理        8\n",
       "3       train/县级公共图书馆文化馆总分馆制信息_大安市.csv  文化休闲        1\n",
       "4    train/县级公共图书馆文化馆总分馆制信息_陇南市武都区.csv  文化休闲        1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = train.merge(df_label, on='label', how='left')\n",
    "\n",
    "print(train.shape)\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_text(x):\n",
    "    text = x.replace('train/', '').replace('.xls', '').replace('.csv', '').replace('_', ' ')\n",
    "    return text\n",
    "\n",
    "train['text'] = train['filename'].apply(lambda x: clean_text(x))\n",
    "test['text'] = test['filename'].apply(lambda x: clean_text(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>label</th>\n",
       "      <th>label_n</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train/市管宗教活动场所.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "      <td>1</td>\n",
       "      <td>市管宗教活动场所</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train/南涧彝族自治县市汶上县残疾人联合会_残疾人证办理_.csv</td>\n",
       "      <td>医疗卫生</td>\n",
       "      <td>3</td>\n",
       "      <td>南涧彝族自治县市汶上县残疾人联合会 残疾人证办理</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train/价格监测信息公开事项汇总信息_.xls</td>\n",
       "      <td>经济管理</td>\n",
       "      <td>8</td>\n",
       "      <td>价格监测信息公开事项汇总信息</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train/县级公共图书馆文化馆总分馆制信息_大安市.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "      <td>1</td>\n",
       "      <td>县级公共图书馆文化馆总分馆制信息 大安市</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train/县级公共图书馆文化馆总分馆制信息_陇南市武都区.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "      <td>1</td>\n",
       "      <td>县级公共图书馆文化馆总分馆制信息 陇南市武都区</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              filename label  label_n  \\\n",
       "0                   train/市管宗教活动场所.csv  文化休闲        1   \n",
       "1  train/南涧彝族自治县市汶上县残疾人联合会_残疾人证办理_.csv  医疗卫生        3   \n",
       "2            train/价格监测信息公开事项汇总信息_.xls  经济管理        8   \n",
       "3       train/县级公共图书馆文化馆总分馆制信息_大安市.csv  文化休闲        1   \n",
       "4    train/县级公共图书馆文化馆总分馆制信息_陇南市武都区.csv  文化休闲        1   \n",
       "\n",
       "                        text  \n",
       "0                   市管宗教活动场所  \n",
       "1  南涧彝族自治县市汶上县残疾人联合会 残疾人证办理   \n",
       "2            价格监测信息公开事项汇总信息   \n",
       "3       县级公共图书馆文化馆总分馆制信息 大安市  \n",
       "4    县级公共图书馆文化馆总分馆制信息 陇南市武都区  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>市管宗教活动场所</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>南涧彝族自治县市汶上县残疾人联合会 残疾人证办理</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>价格监测信息公开事项汇总信息</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>县级公共图书馆文化馆总分馆制信息 大安市</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>县级公共图书馆文化馆总分馆制信息 陇南市武都区</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        text  label\n",
       "0                   市管宗教活动场所      1\n",
       "1  南涧彝族自治县市汶上县残疾人联合会 残疾人证办理       3\n",
       "2            价格监测信息公开事项汇总信息       8\n",
       "3       县级公共图书馆文化馆总分馆制信息 大安市      1\n",
       "4    县级公共图书馆文化馆总分馆制信息 陇南市武都区      1"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df = train[['text', 'label_n']]\n",
    "train_df.columns = ['text', 'label']\n",
    "\n",
    "print(train_df.shape)\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = train_df.sample(frac=1., random_state=1024)\n",
    "\n",
    "eval_df = train_df[54000:]\n",
    "train_df = train_df[:54000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n"
     ]
    }
   ],
   "source": [
    "from simpletransformers.classification import ClassificationModel, ClassificationArgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_args = ClassificationArgs()\n",
    "\n",
    "model_args.max_seq_length = 128\n",
    "model_args.train_batch_size = 16\n",
    "model_args.num_train_epochs = 3\n",
    "model_args.fp16 = False\n",
    "model_args.evaluate_during_training = True\n",
    "model_args.overwrite_output_dir = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = ClassificationModel(\"bert\", \n",
    "                            \"hfl/chinese-roberta-wwm-ext\",\n",
    "                            num_labels=len(df_label),\n",
    "                            args=model_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d72423f7bba48d3b72c51457b972abe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=54000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ecdbfc54d8414879b784c5c564f8639b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "80be35f3e242480eb4cb5cc1b93543a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=3375.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "52db1ec47ece4cb2837e47d3ae2e1773",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=3375.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b321c73d86914ec3a63c3e671bb91b74",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=3375.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(10125,\n",
       " {'global_step': [2000, 3375, 4000, 6000, 6750, 8000, 10000, 10125],\n",
       "  'mcc': [0.9476937708528343,\n",
       "   0.9527763533978135,\n",
       "   0.9577338457403467,\n",
       "   0.9677364423372926,\n",
       "   0.9699235225996395,\n",
       "   0.9722837477080647,\n",
       "   0.9746596788441761,\n",
       "   0.9748424057549449],\n",
       "  'train_loss': [0.15783660113811493,\n",
       "   1.1505966186523438,\n",
       "   0.01191699504852295,\n",
       "   0.0014483332633972168,\n",
       "   0.001240551471710205,\n",
       "   0.11167295277118683,\n",
       "   0.00037664175033569336,\n",
       "   0.0004429817199707031],\n",
       "  'eval_loss': [0.2014222281773885,\n",
       "   0.1762613417506218,\n",
       "   0.1625722239414851,\n",
       "   0.12617032970984776,\n",
       "   0.11873192139466604,\n",
       "   0.11979165108998617,\n",
       "   0.10948107000192006,\n",
       "   0.10929350209236145]})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.train_model(train_df, eval_df=eval_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e171447a92945c18f07adc9ed897d16",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=6000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cccdf11cf348403ea20c2bbb6c16df00",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=750.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'mcc': 0.9748424057549449, 'acc': 0.977, 'eval_loss': 0.10929350209236145}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "result, _, _ = model.eval_model(eval_df, acc=accuracy_score)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "for i, row in test.iterrows():\n",
    "    data.append(row['text'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "03069c315c4d4bcca1e72c514e0055ae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=8000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8d6907f48ae6412ca9d753f9ef6f2132",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "predictions, raw_outputs = model.predict(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8000, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>test1/阳谷县司法所组织信息_.xls</td>\n",
       "      <td>政法监察</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test1/济宁市辽宁省朝阳市龙城区县残疾人联合会_残疾人证办理_.csv</td>\n",
       "      <td>医疗卫生</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>test1/企业变更信息_.xls</td>\n",
       "      <td>工业</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>test1/设立经营性互联网文化单位审批_万州区.csv</td>\n",
       "      <td>文化休闲</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>test1/事业企业“双随机一公开”监察检查记录_.xls</td>\n",
       "      <td>劳动人事</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>test1/昭阳区市广东省名牌产品工业类称号信息_.xls</td>\n",
       "      <td>工业</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>test1/市级民办职业培训学校设立审批_.xls</td>\n",
       "      <td>教育科技</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>test1/分县市、区建筑业生产状况_.xls</td>\n",
       "      <td>城乡建设</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>test1/辽宁省阜新市彰武县市国控污染物自动监控信息_.xls</td>\n",
       "      <td>生态环境</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>test1/水环境重点排污公司名单目录信息_.xls</td>\n",
       "      <td>生态环境</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                filename label\n",
       "0                  test1/阳谷县司法所组织信息_.xls  政法监察\n",
       "1  test1/济宁市辽宁省朝阳市龙城区县残疾人联合会_残疾人证办理_.csv  医疗卫生\n",
       "2                      test1/企业变更信息_.xls    工业\n",
       "3           test1/设立经营性互联网文化单位审批_万州区.csv  文化休闲\n",
       "4          test1/事业企业“双随机一公开”监察检查记录_.xls  劳动人事\n",
       "5          test1/昭阳区市广东省名牌产品工业类称号信息_.xls    工业\n",
       "6              test1/市级民办职业培训学校设立审批_.xls  教育科技\n",
       "7                test1/分县市、区建筑业生产状况_.xls  城乡建设\n",
       "8       test1/辽宁省阜新市彰武县市国控污染物自动监控信息_.xls  生态环境\n",
       "9             test1/水环境重点排污公司名单目录信息_.xls  生态环境"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sub = pd.read_csv('raw_data/submit_example_test1.csv')[['filename']]\n",
    "sub['label_n'] = predictions\n",
    "sub = pd.merge(sub, df_label, on='label_n', how='left')\n",
    "sub.drop(['label_n'], axis=1, inplace=True)\n",
    "\n",
    "print(sub.shape)\n",
    "sub.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub.to_csv('baseline_202011121514.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
